import os
import sys
import random
import numpy as np

from dataset import *
from victim_model.text_classifier import TextClassifier
from victim_model.text_predictor import TextPredictor, TextPredictorEnsembler
from attack_model import *

from config import Config
from tools.color import Color
from tools.logger import Logger
from tools.saver import Saver
from tools.device_manager import DeviceManager
from tools.utils import write_json
from tools.time_counter import TimeCounter

# Module-level configuration object; every function below reads its settings
# (dataset, encoder, token, attacker, paths, ...) from this shared instance.
cf = Config()


def build_data_reader():
    """Select the dataset reader class that matches ``cf.dataset``.

    Returns:
        The reader *class* (not an instance) for the configured dataset:
        ``IMDBDatasetReader`` for ``'imdb'``, ``AGNewsDatasetReader`` for
        ``'agnews'`` or ``MRDatasetReader`` for ``'mr'``.

    Raises:
        ValueError: If ``cf.dataset`` is not one of the supported names.
    """
    if cf.dataset == 'imdb':
        Reader = IMDBDatasetReader
    elif cf.dataset == 'agnews':
        Reader = AGNewsDatasetReader
    elif cf.dataset == 'mr':
        Reader = MRDatasetReader
    else:
        # The message lists every dataset actually handled above, including 'mr'.
        raise ValueError(f'{cf.dataset} not implemented. Only support: imdb, agnews and mr.')
    return Reader


def build_predictor(Reader):
    """Build the victim predictor from the checkpoints named in the config.

    ``cf.encoder`` and ``cf.token`` are comma-separated lists that are paired
    position by position; one ``TextPredictor`` is built per pair.

    Args:
        Reader: Dataset reader class; it is instantiated once per pair as
            ``Reader(cf, token_type=...)``.

    Returns:
        The single ``TextPredictor`` when exactly one pair is configured,
        otherwise a ``TextPredictorEnsembler`` wrapping all of them.

    Raises:
        ValueError: If ``cf.encoder`` and ``cf.token`` do not list the same
            number of entries (``zip`` would otherwise silently drop the
            unmatched ones).
    """
    encoders = cf.encoder.split(',')
    tokens = cf.token.split(',')
    if len(encoders) != len(tokens):
        raise ValueError(
            f'encoder ({cf.encoder}) and token ({cf.token}) must list the same '
            f'number of comma-separated entries.'
        )

    predictors = []
    for e, t in zip(encoders, tokens):
        reader = Reader(cf, token_type=t)
        # NOTE: checkpoints are read from 'ckpt1'. Per the original author's
        # note, d_ckpt should be changed to 'ckpt2' for the universal attack.
        saver = Saver(f'{cf.dataset}_{e}_{t}', d_ckpt='ckpt1')
        model = saver.load_last_epoch(TextClassifier, {'cf': cf, 'encoder_type': e, 'token_type': t})
        # Presumably moves the model to the GPU (torch-style API) -- confirm.
        model = model.cuda()
        predictors.append(TextPredictor(model, reader))

    if len(predictors) == 1:
        return predictors[0]
    return TextPredictorEnsembler(predictors)


def build_attacker(predictor):
    """Create the attacker selected by ``cf.attacker``.

    Args:
        predictor: Victim predictor handed to the attacker's constructor
            together with the global config.

    Returns:
        An attacker instance constructed as ``AttackerClass(cf, predictor)``.

    Raises:
        ValueError: If ``cf.attacker`` names an attack that is not handled here.
    """
    kind = cf.attacker
    if kind == 'hotflip':
        return HotFlipAttacker(cf, predictor)
    if kind == 'pwws':
        return PWWSAttacker(cf, predictor)
    if kind == 'genetic':
        return GeneticAttack(cf, predictor)
    if kind == 'universal':
        return UniversalAttack(cf, predictor)
    if kind == 'random':
        return RandomAttacker(cf, predictor)
    if kind == 'pmi':
        return PMIAttack(cf, predictor)
    raise ValueError(f'{cf.attacker} not supported.')


def exist():
    """Report whether the current configuration already appears in the attack log.

    Each line of ``log/attack.txt`` has its ``'['`` characters removed and is
    split on whitespace; the first resulting field is compared against
    ``'{dataset}_{attacker}_{encoder}_{token}'`` built from the global config.

    Returns:
        True if some log line starts with that identifier; False otherwise,
        including when the log file does not exist yet.
    """
    target = f'{cf.dataset}_{cf.attacker}_{cf.encoder}_{cf.token}'
    try:
        with open('log/attack.txt', 'r', encoding='utf-8') as f:
            for line in f:
                fields = line.replace('[', '').split()
                # Blank lines split to an empty list; skip them instead of
                # raising IndexError on fields[0].
                if fields and fields[0] == target:
                    return True
    except FileNotFoundError:
        # No log has been written yet, so nothing can have been recorded.
        return False
    return False


def _mean_or_nan(values):
    """Return ``np.average(values)``, or NaN for an empty list without NumPy's empty-slice warning."""
    if not values:
        return float('nan')
    return np.average(values)


def main():
    """Attack the test split and save the resulting adversarial examples.

    Builds the reader, victim predictor and attacker from the global config,
    runs the attacker over the test instances, then logs the success rate and
    the average ``length`` of successful attacks, and writes:

    * successful adversarial examples to ``cf.p_adv['adv_examples']``,
    * the ``adv_example`` of every attacked instance (successful or not) to
      ``cf.p_adv['semi_adv_examples']``,
    * the raw per-instance attacker results to ``cf.p_adv['detail']``.
    """
    logger = Logger(cf.p_log['attack'], quiet=cf.quiet)
    print(cf)

    Reader = build_data_reader()
    predictor = build_predictor(Reader)
    attacker = build_attacker(predictor)

    logger.print(Color.green('reading data...'))
    reader = Reader(cf, token_type='word')
    victim_dataset = reader.read_json(cf.p_split['test'])

    logger.print(Color.green('attacking...'))
    # success: one flag per attacked instance; length: one entry per success.
    success, adv_examples, semi_adv_examples, length = [], [], [], []
    details = []
    for i, instance in enumerate(victim_dataset):
        ret = attacker(instance)

        success.append(ret['success'])
        semi_adv_examples.append(ret['adv_example'])
        details.append(ret)
        if ret['success']:
            length.append(ret['length'])
            adv_examples.append(ret['adv_example'])

        if not cf.quiet:
            sys.stdout.write(
                f'\r{i} SR: {_mean_or_nan(success)}, average subsitude length: {_mean_or_nan(length)}'
            )
            # The progress line has no newline, so flush to make it visible.
            sys.stdout.flush()
        # Stop once enough instances have been attacked AND enough successful
        # adversarial examples have been collected.
        if len(success) >= cf.attack_dataset_size and len(length) >= cf.transfer_dataset_size:
            break

    result = f'[{cf.adv_id} {cf.model_id}] ' \
             f'SR: {_mean_or_nan(success)}, average subsitude length: {_mean_or_nan(length)}'

    logger.print(Color.yellow(result))
    logger.log(result)
    reader._write(cf.p_adv['adv_examples'], adv_examples)
    reader._write(cf.p_adv['semi_adv_examples'], semi_adv_examples)
    write_json(cf.p_adv['detail'], details)


if __name__ == '__main__':
    # Run the whole attack inside the DeviceManager context for cf.device.
    with DeviceManager(cf.device):
        # The exist() guard below is disabled, so the attack always runs even
        # if this configuration was already logged.
        # if not exist():
        main()
